# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Convert a dataset to HDF5.
    This script preprocesses a dataset and saves its images and labels to
    a single HDF5 file for faster I/O.
"""
import os
from argparse import ArgumentParser

import h5py as h5
from tqdm import tqdm

import utils


def prepare_parser():
    """Build the command-line parser for the HDF5 conversion script.

    Returns:
        argparse.ArgumentParser: parser whose parsed namespace (converted with
        ``vars``) is the ``config`` dict consumed by ``run``.
    """
    usage = 'Parser for ImageNet HDF5 scripts.'
    parser = ArgumentParser(description=usage)
    # Adjacent string literals are concatenated, so the trailing space before
    # the closing quote is required to keep the two sentences apart.
    parser.add_argument(
        '--dataset', type=str, default='I128',
        help='Which dataset to convert, out of I128, I256, C10, C100; '
             'HDF5 variants (suffix "_hdf5") are rejected (default: %(default)s)')
    parser.add_argument(
        '--data_root', type=str, default='data',
        help='Default location where data is stored (default: %(default)s)')
    parser.add_argument(
        '--batch_size', type=int, default=256,
        help='Default overall batchsize (default: %(default)s)')
    parser.add_argument(
        '--num_workers', type=int, default=16,
        help='Number of dataloader workers (default: %(default)s)')
    parser.add_argument(
        '--chunk_size', type=int, default=500,
        help='Number of images per HDF5 chunk (default: %(default)s)')
    parser.add_argument(
        '--compression', action='store_true', default=False,
        help='Use LZF compression? (default: %(default)s)')
    return parser


def _create_hdf5_datasets(f, x, y, num_samples, chunk_size, compression):
    """Create resizable 'imgs' and 'labels' datasets in ``f``, seeded with the first batch.

    Args:
        f: open, writable ``h5py.File``.
        x: first image batch as a numpy array, batch dimension first.
        y: first label batch as a numpy array, batch dimension first.
        num_samples: maximum number of samples the datasets may grow to.
        chunk_size: requested number of samples per HDF5 chunk.
        compression: h5py compression filter name (e.g. 'lzf') or None.
    """
    # h5py refuses chunk shapes larger than maxshape, so cap the chunk length
    # for datasets smaller than the requested chunk size.
    chunk_len = min(chunk_size, num_samples)
    sample_shape = tuple(x.shape[1:])
    imgs_dset = f.create_dataset('imgs', x.shape, dtype='uint8',
                                 maxshape=(num_samples,) + sample_shape,
                                 chunks=(chunk_len,) + sample_shape,
                                 compression=compression)
    print('Image chunks chosen as ' + str(imgs_dset.chunks))
    imgs_dset[...] = x
    labels_dset = f.create_dataset('labels', y.shape, dtype='int64', maxshape=(num_samples,),
                                   chunks=(chunk_len,), compression=compression)
    print('Label chunks chosen as ' + str(labels_dset.chunks))
    labels_dset[...] = y


def _append_to_hdf5_datasets(f, x, y):
    """Grow 'imgs' then 'labels' along axis 0 and write one batch at the end of each."""
    for name, data in (('imgs', x), ('labels', y)):
        dset = f[name]
        dset.resize(dset.shape[0] + data.shape[0], axis=0)
        dset[-data.shape[0]:] = data


def run(config):
    """Load the training split of ``config['dataset']`` and write it to an HDF5 file.

    The output file is ``<data_root>/ILSVRC<image_size>.hdf5``. It holds a uint8
    'imgs' dataset and an int64 'labels' dataset, built batch by batch from the
    data loader returned by ``utils.get_data_loaders``.

    Args:
        config: dict as produced by ``vars(prepare_parser().parse_args())``.
            It is modified in place: 'image_size' is added and 'compression'
            is rewritten to an h5py filter name ('lzf') or None.

    Raises:
        ValueError: if the dataset name contains 'hdf5', because the conversion
            would likely overwrite the file it reads from.
    """
    if 'hdf5' in config['dataset']:
        raise ValueError('Reading from an HDF5 file which you will probably be '
                         'about to overwrite! Override this error only if you know '
                         "what you're doing!")
    # Get image size
    config['image_size'] = utils.imsize_dict[config['dataset']]

    # Map the boolean CLI flag to an h5py compression filter (None = uncompressed).
    config['compression'] = 'lzf' if config['compression'] else None

    # Get dataset
    kwargs = {'num_workers': config['num_workers'], 'pin_memory': False, 'drop_last': False}
    train_loader = utils.get_data_loaders(dataset=config['dataset'],
                                          batch_size=config['batch_size'],
                                          shuffle=False,
                                          data_root=config['data_root'],
                                          use_multiepoch_sampler=False,
                                          **kwargs)[0]

    # HDF5 supports chunking and compression. You may want to experiment
    # with different chunk sizes to see how it runs on your machines.
    # Chunk Size/compression     Read speed @ 256x256   Read speed @ 128x128  Filesize @ 128x128    Time to write @128x128
    # 1 / None                   20/s
    # 500 / None                 ramps up to 77/s       102/s                 61GB                  23min
    # 500 / LZF                                         8/s                   56GB                  23min
    # 1000 / None                78/s
    # 5000 / None                81/s
    # auto:(125,1,16,32) / None                         11/s                  61GB

    print('Starting to load %s into an HDF5 file with chunk size %i and compression %s...' % (
        config['dataset'], config['chunk_size'], config['compression']))
    num_samples = len(train_loader.dataset)
    # os.path.join avoids writing to the filesystem root when data_root is ''.
    hdf5_path = os.path.join(config['data_root'], 'ILSVRC%i.hdf5' % config['image_size'])
    # Open the output once for the whole conversion rather than once per batch.
    with h5.File(hdf5_path, 'w') as f:
        for i, (x, y) in enumerate(tqdm(train_loader)):
            # Map loader output to uint8 in [0, 255]; presumably the loader
            # yields images normalized to [-1, 1] -- TODO confirm in utils.
            x = (255 * ((x + 1) / 2.0)).byte().numpy()
            # Numpyify y
            y = y.numpy()
            if i == 0:
                # First batch fixes the per-sample shape and creates the datasets.
                print('Producing dataset of len %d' % num_samples)
                _create_hdf5_datasets(f, x, y, num_samples,
                                      config['chunk_size'], config['compression'])
            else:
                _append_to_hdf5_datasets(f, x, y)


def main():
    """Command-line entry point: build the config from argv, echo it, then convert."""
    args = prepare_parser().parse_args()
    config = vars(args)
    print(config)
    run(config)


# Only run the conversion when executed as a script, not on import.
if __name__ == '__main__':
    main()
